// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

#ifdef GTEST
#include "gtest/gtest.h"
#endif
#include "test/common/AcceleratorHarness.h"
#include "test/common/GoldModel.h"
#include "test/common/Utils.h"

// Runs a single accelerator configuration and checks it against the gold model.
//
// Builds two operand matrices with deterministic patterns:
//   * A: rows = M0 * M1, cols = N1 * DIMENSION
//   * B: rows = N1 * DIMENSION, cols = P1 * P2 * DIMENSION
// Each matrix is stored row-major both in a host copy and in the simulated
// main memory, at params.INPUT_OFFSET and params.WEIGHT_OFFSET respectively.
// Offsets are counted in INPUT_DATATYPE elements. Then:
//   * run_op() executes the accelerator against main memory,
//   * run_gold_op() computes the reference result into matrixC,
//   * compare_arrays() compares the M0 * M1 * P1 * P2 * DIMENSION elements of
//     main memory starting at params.OUTPUT_OFFSET against matrixC.
//
// All buffers are std::vectors, so they are released even if a callee throws.
// Main memory is value-initialized (zero for arithmetic element types), so
// regions this function never writes do not hold indeterminate values.
void run_simple_test(Params params) {
  // Size of the simulated main memory, in INPUT_DATATYPE elements.
  const std::size_t kMainMemorySize = 4 * 1024 * 1024;
  std::vector<INPUT_DATATYPE> mainMemory(kMainMemorySize);

  const auto rowsA = params.M0 * params.M1;
  const auto colsA = params.N1 * DIMENSION;
  const auto rowsB = params.N1 * DIMENSION;
  const auto colsB = params.P1 * params.P2 * DIMENSION;
  const auto sizeC = params.M0 * params.M1 * params.P1 * params.P2 * DIMENSION;

  // Every region written or read through mainMemory must lie inside it;
  // otherwise the loops below (or compare_arrays) would index out of bounds.
  assert(static_cast<std::size_t>(params.INPUT_OFFSET) + rowsA * colsA <=
         kMainMemorySize);
  assert(static_cast<std::size_t>(params.WEIGHT_OFFSET) + rowsB * colsB <=
         kMainMemorySize);
  assert(static_cast<std::size_t>(params.OUTPUT_OFFSET) + sizeC <=
         kMainMemorySize);

  // Create matrix A: element (i, j) = i * 10 + j.
  std::vector<INPUT_DATATYPE> matrixA(rowsA * colsA);
  for (int i = 0; i < rowsA; i++) {
    for (int j = 0; j < colsA; j++) {
      // Alternative random data: int val = rand() % 128;
      const int val = i * 10 + j;

      mainMemory[params.INPUT_OFFSET + i * colsA + j] = val;
      matrixA[i * colsA + j] = val;
    }
  }

  // Create matrix B: every element of row i equals i.
  std::vector<INPUT_DATATYPE> matrixB(rowsB * colsB);
  for (int i = 0; i < rowsB; i++) {
    for (int j = 0; j < colsB; j++) {
      // Alternative random data: int val = rand() % 128;
      const int val = i;

      mainMemory[params.WEIGHT_OFFSET + i * colsB + j] = val;
      matrixB[i * colsB + j] = val;
    }
  }

  // Reference output buffer, filled by the gold model.
  std::vector<OUTPUT_DATATYPE> matrixC(sizeC);

  run_op(params, mainMemory.data());
  run_gold_op(params, matrixA.data(), matrixB.data(), matrixC.data());
  compare_arrays(mainMemory.data() + params.OUTPUT_OFFSET, matrixC.data(),
                 sizeC);
  // check_outputs(params, mainMemory, matrixC);
}

#ifdef GTEST
TEST(Simple, Basic) {
#else
void basic() {
#endif
  std::cout << "Basic Test" << std::endl;
  std::cout << "----------" << std::endl;

  const Params params = {
      16,                       // M0
      2,                        // P1
      2,                        // N1
      1,                        // M1
      1,                        // P2
      0,                        // INPUT_OFFSET
      1024 * 1024,              // WEIGHT_OFFSET
      2 * 1024 * 1024,          // OUTPUT_OFFSET
      false,                    // SOFTMAX
      1,                        // SCALE
      false,                    // TRANSPOSE
      0,                        // VECTOR_OFFSET
      false,                    // VEC_OP
      false,                    // VEC_SUB
      false,                    // VEC_SQUARE
      false,                    // VEC_REDUCE
      true,                     // CONST_SCALE
      0,                        // VEC_SCALE_OFFSET
      0,                        // VEC_SUB_OFFSET
      false,                    // RELU
      {{1, 1, 1}, {2, 2, 16}},  // loops
      {1, 2},                   // input
      {2, 0},                   // reduction
      {0, 1}                    // weight
  };

  run_simple_test(params);
}

#ifdef GTEST
TEST(Simple, HeadSplit) {
#else
void headsplit() {
#endif
  std::cout << "Head Split Test" << std::endl;
  std::cout << "----------" << std::endl;

  const Params params = {
      16,                       // M0
      2,                        // P1
      2,                        // N1
      1,                        // M1
      1,                        // P2
      0,                        // INPUT_OFFSET
      1024 * 1024,              // WEIGHT_OFFSET
      2 * 1024 * 1024,          // OUTPUT_OFFSET
      false,                    // SOFTMAX
      1,                        // SCALE
      false,                    // TRANSPOSE
      0,                        // VECTOR_OFFSET
      false,                    // VEC_OP
      false,                    // VEC_SUB
      false,                    // VEC_SQUARE
      false,                    // VEC_REDUCE
      true,                     // CONST_SCALE
      0,                        // VEC_SCALE_OFFSET
      0,                        // VEC_SUB_OFFSET
      false,                    // RELU
      {{1, 1, 1}, {2, 2, 16}},  // loops
      {1, 2},                   // input
      {2, 0},                   // reduction
      {0, 1},                   // weight
      false,                    // CONCAT_HEAD
      true,                     // SPLIT_HEAD
      2,                        // HEAD_SZ_LG2
  };

  run_simple_test(params);
}

#ifdef GTEST
TEST(Simple, HeadConcat) {
#else
void headconcat() {
#endif
  std::cout << "Head Concat Test" << std::endl;
  std::cout << "----------" << std::endl;

  const Params params = {
      16,                       // M0
      2,                        // P1
      2,                        // N1
      1,                        // M1
      1,                        // P2
      0,                        // INPUT_OFFSET
      1024 * 1024,              // WEIGHT_OFFSET
      2 * 1024 * 1024,          // OUTPUT_OFFSET
      false,                    // SOFTMAX
      1,                        // SCALE
      false,                    // TRANSPOSE
      0,                        // VECTOR_OFFSET
      false,                    // VEC_OP
      false,                    // VEC_SUB
      false,                    // VEC_SQUARE
      false,                    // VEC_REDUCE
      true,                     // CONST_SCALE
      0,                        // VEC_SCALE_OFFSET
      0,                        // VEC_SUB_OFFSET
      false,                    // RELU
      {{1, 1, 1}, {2, 2, 16}},  // loops
      {1, 2},                   // input
      {2, 0},                   // reduction
      {0, 1},                   // weight
      true,                     // CONCAT_HEAD
      false,                    // SPLIT_HEAD
      2,                        // HEAD_SZ_LG2
  };

  run_simple_test(params);
}

// SystemC entry point.
//
// With GTEST defined, hands control to GoogleTest and returns its result.
// Otherwise runs each test directly, in order, and returns 0.
// The explicit return matters. Unlike main(), sc_main is an ordinary function
// with no implicit "return 0", so flowing off its end is undefined behavior.
int sc_main(int argc, char *argv[]) {
#ifdef GTEST
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
#else
  // Command-line arguments are only consumed by GoogleTest.
  static_cast<void>(argc);
  static_cast<void>(argv);

  basic();
  headsplit();
  headconcat();
  return 0;
#endif
}
